SDPA decode perf improvements for qwen-3.5-35B-A3B #18759
digantdesai wants to merge 6 commits into main from
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18759
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 4 New Failures, 3 Unrelated Failures as of commit 62428be with merge base 930ecfd.
NEW FAILURES: the following jobs have failed.
BROKEN TRUNK: the following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
Force-pushed from 2cb04c3 to febc419.
Pull request overview
This PR improves ExecuTorch CUDA SDPA decode performance for the common decode case where Lq = 1 (e.g., Qwen3.5 MoE generation), by introducing a Split-K “flash-decoding” Triton path and dispatching to it at runtime.
Changes:
- Add a Split-K decode SDPA Triton kernel (`sdpa_decode_splitk`) plus a reduction kernel to improve occupancy when `L_q == 1`.
- Update the Qwen3.5 MoE attention path to dispatch between Split-K (decode) and tiled SDPA (prefill) via `torch.cond`.
- Add correctness tests and a benchmark script for SDPA decode shapes; update export example shapes to avoid overly-small AOTI shape specialization.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| examples/models/qwen3_5_moe/model.py | Switch attention to Triton SDPA and add decode-time Split-K dispatch via torch.cond. |
| examples/models/qwen3_5_moe/main.cpp | Plumb a stats callback into generation and print throughput/timing breakdown. |
| examples/models/qwen3_5_moe/export.py | Use a max-length example sequence to prevent AOTI from baking in too-small intermediate buffers. |
| backends/cuda/triton/kernels/sdpa.py | Implement Split-K decode kernel + reduction and expose sdpa_decode_splitk. |
| `backends/cuda/triton/kernels/__init__.py` | Export sdpa_decode_splitk from the kernels package. |
| backends/cuda/tests/test_triton_sdpa_splitk.py | Add CUDA BF16 unit tests validating Split-K correctness vs PyTorch SDPA reference. |
| backends/cuda/benchmarks/benchmark_sdpa.py | Add a benchmark script comparing Triton SDPA/Split-K vs PyTorch SDPA backends. |
```python
@triton_op("triton::sdpa_decode_splitk", mutates_args={})
def sdpa_decode_splitk(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float = 0.0,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Split-K flash-decoding SDPA for L_q=1 (decode step).

    Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
    enable_gqa is accepted but ignored — GQA is handled natively via
    H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
    """
    B, H_q, L_q, D = query.shape
    _, H_kv, L_kv, _ = key.shape

    out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
```
sdpa_decode_splitk() launches kernels that assume CUDA + bfloat16 inputs (and the reduce kernel stores bfloat16 unconditionally), but unlike sdpa() it never calls _validate_sdpa_inputs(). This can lead to hard-to-debug failures or incorrect behavior if the op is called with non-bfloat16 or non-CUDA tensors. Consider reusing _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa) at the top of sdpa_decode_splitk (and keep the existing feature/shape checks).
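The suggested guard could look like the following standalone sketch. This is hypothetical: the function name and exact error messages are illustrative, and the real fix would likely just reuse the existing `_validate_sdpa_inputs()` as the reviewer proposes.

```python
def validate_decode_inputs(query, key, value, dropout_p=0.0):
    """Hypothetical fail-fast guard mirroring the checks the review asks for:
    CUDA device, bfloat16 dtype, and no dropout in the decode path."""
    for name, t in (("query", query), ("key", key), ("value", value)):
        if t.device.type != "cuda":
            raise ValueError(f"{name} must be a CUDA tensor, got {t.device.type}")
        if str(t.dtype) != "torch.bfloat16":
            raise ValueError(f"{name} must be bfloat16, got {t.dtype}")
    if dropout_p != 0.0:
        raise ValueError("dropout is not supported in the decode path")
```

Running this at the top of `sdpa_decode_splitk` would turn silent wrong-dtype/wrong-device launches into immediate, descriptive errors.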
```python
    HAS_MASK=HAS_MASK,
    HEAD_DIM=D,
    NUM_GROUPS=num_groups,
    BLOCK_G=_next_power_of_2(num_groups),
)
```
BLOCK_G is derived via _next_power_of_2(num_groups), but _next_power_of_2() is documented as clamping to [16, 256] for head-dim tiling. Using it for num_groups forces a minimum of 16 (wasted work for MHA/low-group configs) and would be incorrect if num_groups > 256 (since BLOCK_G would clamp to 256 and not cover all groups). Consider a separate next_pow2 helper for group counts (no min clamp) and/or add a clear runtime check that num_groups <= 256 if that’s an intended constraint.
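A sketch of the separate helper this comment suggests, with no minimum or maximum clamp. The name is illustrative, not from the PR:

```python
def next_pow2_unclamped(n: int) -> int:
    """Smallest power of two >= n, with no [16, 256] clamp.

    Suitable for tiling GQA group counts, where forcing a minimum
    of 16 wastes work for MHA (num_groups == 1) and a 256 cap would
    silently under-cover very large group counts.
    """
    if n < 1:
        raise ValueError("n must be >= 1")
    return 1 << (n - 1).bit_length()
```

For typical GQA configs this yields `next_pow2_unclamped(8) == 8` rather than the clamped 16, so the kernel tile covers exactly the groups that exist.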
```python
device_name = torch.cuda.get_device_name()
print()
print("=" * 100)
print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")
```
The benchmark header prints "Qwen3.5-35B-A4B", but this benchmark (and the PR description) refers to the A3B variant. This looks like a typo and can confuse readers when comparing numbers; consider correcting the printed model name.
```diff
- print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")
+ print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}")
```
@claude review this PR.
Claude finished @digantdesai's task in 2m 21s — View job

PR Review: SDPA decode perf improvements for qwen-3.5-35B-A3B

Overall this is a well-structured PR. The Split-K flash-decoding approach is the right strategy for L_q=1 with low H_kv — the occupancy problem is real and the 1.45x end-to-end speedup is meaningful.

Issues

1. Missing input validation in `sdpa_decode_splitk`. Call …
2. … This doesn't cause incorrect results (the …
3. Reduce kernel output addressing assumes B=1. The reduce kernel uses … The same pattern applies to the partial buffers …
4. Benchmark typo: "A4B" should be "A3B" in `print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")`; it should print "A3B" to match the model variant. Copilot flagged this too.

Suggestions (non-blocking)

5. Partial accumulator stores unscaled …: the split-K kernel stores the raw unnormalized accumulator …
6. `num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128)`: the choice of 256 tokens per split and max 128 splits is reasonable for A100 with ~108 SMs, but this could benefit from being SM-count-aware (e.g., …
7. Test tolerance: all tests use …
8. The dispatch:

```python
y = torch.cond(
    q.shape[2] == 1,
    lambda q, k, v, mask: sdpa_decode_splitk(q, k, v, attn_mask=mask),
    lambda q, k, v, mask: sdpa(q, k, v, attn_mask=mask, enable_gqa=True),
    [q, k, v, attn_mask],
)
```

Note that during AOTI tracing, both branches are traced with symbolic shapes, so …
9. Export shape change: changing …

Summary

The kernel implementation is correct and well-tested. The main actionable item is issue #1 (missing CUDA/bf16 validation in `sdpa_decode_splitk`).
Compares ET Triton SDPA (native GQA) against PyTorch Flash/Efficient/Math backends (expanded KV) across Lk=64..16K on A100. Uses triton.testing.do_bench for timing. Standalone script, no changes to the kernel.
Add a stats_callback to generate() that prints prefill/decode rates, model load time, TTFT, and sampling time via printf, mirroring the format in extension/llm/runner/stats.h print_report. Uses printf instead of ET_LOG(Info) because the CMake target does not link executorch_no_prim_ops (which provides the PAL logger); adding that dependency pulls in the full runtime and breaks the minimal runner build.
Register `triton::sdpa_decode_splitk` as an independent op so AOTI can trace and compile it without the runtime L_kv conditional that prevents the split-K path from appearing in the standard `sdpa` op. The split-K (flash-decoding) approach partitions the KV sequence across CTAs and reduces partial softmax results in a second kernel. The benchmark script now includes the split-K column for comparison.

BLOCK_G (the GQA group tile) uses _next_power_of_2_unclamped() to avoid inflating small group counts to 16. Phantom rows from over-sized tiles change register pressure and instruction scheduling, altering fp32 accumulation order enough to degrade output quality over long autoregressive sequences.

Standalone kernel benchmark on H100 (Qwen3.5 MoE decode, B=1, H_q=16, H_kv=2, D=256, bf16):

| Lk | ET Tiled (us) | ET Split-K (us) | Speedup |
|---|---|---|---|
| 64 | 131.8 | 259.5 | 0.5x |
| 512 | 98.9 | 221.5 | 0.4x |
| 4096 | 199.9 | 214.4 | 0.9x |
| 8192 | 392.2 | 211.3 | 1.9x |
| 16384 | 775.3 | 211.8 | 3.7x |

Split-K breaks even around Lk=4096 and dominates at longer sequences where the tiled kernel's single-CTA-per-head bottleneck becomes severe.
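The two-kernel reduction described above can be illustrated in plain Python. This is a scalar sketch of the math only, not the Triton kernels: each split keeps a running max `m`, softmax denominator `l`, and unnormalized accumulator `acc`, and the reduce step merges per-split partials with a log-sum-exp rescale.

```python
import math

def split_attend(q, keys, vals, scale):
    """First kernel (one split): online-softmax partials over this
    split's KV slice. Returns (running max, denominator, unnormalized acc)."""
    m, l, acc = -math.inf, 0.0, 0.0
    for k, v in zip(keys, vals):
        s = q * k * scale  # scalar stand-in for the q.k dot product
        m_new = max(m, s)
        corr = math.exp(m - m_new) if m != -math.inf else 0.0
        l = l * corr + math.exp(s - m_new)
        acc = acc * corr + math.exp(s - m_new) * v
        m = m_new
    return m, l, acc

def reduce_splits(partials):
    """Second kernel: merge per-split (m, l, acc) triples by rescaling
    each split to the global max, then normalize once at the end."""
    m_glob = max(m for m, _, _ in partials)
    l_glob = sum(l * math.exp(m - m_glob) for m, l, _ in partials)
    acc_glob = sum(acc * math.exp(m - m_glob) for m, _, acc in partials)
    return acc_glob / l_glob
```

Merging partials this way reproduces the exact softmax-weighted sum regardless of how the KV sequence is partitioned, which is what lets the splits run on independent CTAs.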
The previous example used T=2, which caused AOTI to compile the
chunk_gated_delta_rule kernel for a single chunk (NT=1). At runtime,
prompts longer than 64 tokens (requiring NT>1 chunks) failed with
"Error resizing tensor at input 0". Using max_seq_len-1 as the
example ensures AOTI generalizes intermediate buffer sizes for the
full sequence length range.
Comparison against original export (tq4_sdpa fused kernel)
on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096, 5 runs median):
| | Original (tq4_sdpa) | Baseline (Triton SDPA) |
|---|---|---|
| Decode tok/s | 68.4 | 61.7 |
| Prefill tok/s | 275.7 | 378.2 |
Baseline prefill is 1.37x faster; decode is 0.90x (tq4_sdpa's fused
decode kernel is faster than the tiled Triton SDPA at L_q=1). The
split-K commit addresses the decode gap.
Runtime dispatch via torch.cond in FullAttention: split-K flash-decoding
for decode (L_q==1) and standard tiled SDPA for prefill (L_q>1). Guard
sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing
with symbolic shapes doesn't trip the L_q==1 check.
Align sdpa_decode_splitk signature with sdpa (dropout_p, is_causal,
enable_gqa) for drop-in use with torch.cond; unsupported args fail
with clear messages.
End-to-end on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096,
1024 decode tokens, prompt="Hi", temperature=0, 5 runs median):
| | Baseline (tiled) | Split-K | Speedup |
|---|---|---|---|
| Decode tok/s | 61.7 | 89.9 | 1.46x |
| Prefill tok/s | 378.2 | 378.2 | 1.00x |
| nsys GPU time | 13853 ms | 8674 ms | 1.60x |
| SDPA kernel | 5370 ms (38.8%) | 209 ms (2.4%) | 25.7x |
Force-pushed from ebe61e8 to 5d3b620.
Import ordering, line-length wrapping, and missing blank lines flagged by CI lintrunner.
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
```cpp
printf(
    "\n\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
    stats.num_prompt_tokens,
    stats.num_generated_tokens);
```
```python
for name, label, _ in backends:
    if name == ref_name or outputs[name] is None:
        continue
    err = _max_abs_error(outputs[name], ref_out)
    assert err < 1e-2, (
        f"Output mismatch for {_shape_label(shape)}: "
        f"{label} vs {BACKENDS[ref_name][0]}, "
        f"max abs error {err:.3e} >= 1e-2"
    )
```
```python
out = self.splitk(q, k, v, attn_mask=mask)

self.assertFalse(torch.isnan(out).any(), "All-masked should not NaN")
self.assertFalse(torch.isinf(out).any(), "All-masked should not Inf")
```
Performance Improvements for SDPA
Improves SDPA performance for decode sequences where $L_q = 1$.
Benchmark: qwen3.5-35B-A3B
generate = 1024 tokens, median of 3 runs on A100. (~25x speedup at the SDPA op level: across ~10.2K calls, i.e. 1024 tokens x 10 layers, SDPA time dropped from 5.3 s to 209 ms.)
Implementation Details
_sdpa_fwd_kernel_m64).